In this demo, we train the qGAN to generate an $8 \times 8$ pixel-art smiley face. Enjoy :)
First, import the required Python modules:
import numpy as np
import plotly.io as pio
pio.renderers.default = "notebook"
Next, import our custom qGAN image generator class, ImageGenerator, from src/imgen.py.
from src.imgen import ImageGenerator
Now, set the circuit and sampling hyperparameters and initialize the input pixel art.
NUM_QUBITS = 6
NUM_LAYERS = 2
EPOCH_SAMPLE_SIZE = 10**4
BATCH_SAMPLE_SIZE = 10**3
pixel_art = np.array([
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 1, 0, 0],
[0, 0, 1, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 1, 0],
[0, 1, 0, 0, 0, 0, 1, 0],
[0, 0, 1, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0]
], dtype=float)  # np.float was removed in NumPy 1.24; use the builtin float
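With NUM_QUBITS = 6, the circuit has $2^6 = 64$ computational basis states, exactly one per pixel of the $8 \times 8$ image. The sketch below shows one natural way to turn the pixel art into a target probability distribution over those states. It assumes row-major flattening (pixel (row, col) maps to basis index 8 * row + col), which may differ from what ImageGenerator does internally. The helper pixel_to_bitstring is hypothetical and only here for illustration.

```python
import numpy as np

NUM_QUBITS = 6

pixel_art = np.array([
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 1, 0],
    [0, 1, 0, 0, 0, 0, 1, 0],
    [0, 0, 1, 1, 1, 1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
], dtype=float)

# One basis state per pixel: 2**6 == 8 * 8.
assert 2 ** NUM_QUBITS == pixel_art.size

# Assumed row-major mapping: pixel (row, col) -> basis index 8 * row + col.
# Normalizing makes the lit pixels equally likely outcomes.
target = pixel_art.flatten() / pixel_art.sum()


def pixel_to_bitstring(row, col, width=8, num_qubits=NUM_QUBITS):
    """Hypothetical helper: the basis-state label (as bits) for a given pixel."""
    return format(row * width + col, f"0{num_qubits}b")
```

For example, pixel_to_bitstring(1, 2) returns "001010", the label of the top pixel of the left eye under this mapping. The smiley has 12 lit pixels, so each of them gets probability 1/12 in target.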
Initialize the qGAN for image generation.
imgen = ImageGenerator(
NUM_QUBITS, NUM_LAYERS,
epoch_sample_size=EPOCH_SAMPLE_SIZE, batch_sample_size=BATCH_SAMPLE_SIZE,
enable_remapping=True
)
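The constructor arguments imply a few quick numbers. The sketch below works them out, assuming that ImageGenerator splits each epoch's samples into equal batches and takes one optimizer step per batch, and that the generator has one variational parameter per qubit per layer. Neither assumption has been checked against src/imgen.py, although the second is consistent with the shape of the array returned by get_final_params at the end of this demo.

```python
NUM_QUBITS = 6
NUM_LAYERS = 2
EPOCH_SAMPLE_SIZE = 10**4
BATCH_SAMPLE_SIZE = 10**3
NUM_EPOCHS = 300  # value used in the training cell of this demo

# Size of the generator's output space: one basis state per pixel.
num_outcomes = 2 ** NUM_QUBITS  # 64

# Assumption: equal-sized batches, one optimizer step per batch.
batches_per_epoch = EPOCH_SAMPLE_SIZE // BATCH_SAMPLE_SIZE  # 10
total_steps = batches_per_epoch * NUM_EPOCHS  # 3000

# Assumption: one parameter per qubit per layer.
num_generator_params = NUM_LAYERS * NUM_QUBITS  # 12
```

Under these assumptions, "steps" in the loss plot would be batches rather than epochs, so a window of 10 steps would cover roughly one epoch.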
Load the pixel art. After preprocessing with a Gaussian filter, this is what your original art looks like!
imgen.load_image(pixel_art, blur_sigma=0.6, show_figure=True)
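The blur_sigma argument controls the width of the Gaussian filter applied to the image. The following is a minimal NumPy sketch of that preprocessing step. It assumes a separable Gaussian kernel truncated at 4 sigma with mirrored borders (intended to mimic the defaults of scipy.ndimage.gaussian_filter), followed by normalization to a probability distribution. The exact filter and normalization inside load_image may differ.

```python
import numpy as np


def gaussian_kernel_1d(sigma, truncate=4.0):
    """Normalized 1D Gaussian kernel of radius int(truncate * sigma + 0.5)."""
    radius = int(truncate * sigma + 0.5)
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def blur_and_normalize(image, sigma):
    """Separable Gaussian blur with mirrored borders, then scale to sum to 1."""
    kernel = gaussian_kernel_1d(sigma)
    radius = len(kernel) // 2
    # "symmetric" padding repeats the edge pixel (d c b a | a b c d | d c b a).
    padded = np.pad(image, radius, mode="symmetric")
    smooth = lambda v: np.convolve(v, kernel, mode="valid")
    rows_done = np.apply_along_axis(smooth, 1, padded)  # blur along each row
    blurred = np.apply_along_axis(smooth, 0, rows_done)  # then along each column
    return blurred / blurred.sum()


# Same smiley as pixel_art above, built from its lit (row, col) positions.
pixel_art = np.zeros((8, 8))
for r, c in [(1, 2), (1, 5), (2, 2), (2, 5), (4, 1), (4, 6),
             (5, 1), (5, 6), (6, 2), (6, 3), (6, 4), (6, 5)]:
    pixel_art[r, c] = 1.0

blurred = blur_and_normalize(pixel_art, sigma=0.6)
```

With sigma = 0.6 the kernel has radius 2, so every pixel within two rows and two columns of a lit pixel receives some probability mass. For this smiley that covers the whole image, so the blurred target has no exact zeros.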
Train the qGAN and watch how the generator's output changes over 300 epochs! (Spoiler alert: your smiley face reappears after only ~30 epochs!)
NUM_EPOCHS = 300
imgen.train(imgen.make_dataset(), NUM_EPOCHS, show_progress=True)
Generate an interactive animation that shows the probability distribution of the circuit output at each training epoch.
imgen.plot_output_distribution_history()
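Each frame of the animation shows a probability distribution over the 64 basis states. If that distribution is estimated from measurement samples (our reading of EPOCH_SAMPLE_SIZE = 10**4, not something confirmed in src/imgen.py), a histogram of the shots is the usual estimator. The sketch below does this on toy data drawn from a random distribution, not on the demo's actual output, and reshapes the result back into an $8 \times 8$ image.

```python
import numpy as np


def empirical_distribution(samples, num_outcomes=64):
    """Fraction of shots that landed on each basis state 0..num_outcomes-1."""
    counts = np.bincount(samples, minlength=num_outcomes)
    return counts / counts.sum()


# Toy data: draw 10**4 shots from a known random distribution, then re-estimate it.
rng = np.random.default_rng(seed=0)
true_probs = rng.random(64)
true_probs /= true_probs.sum()
shots = rng.choice(64, size=10**4, p=true_probs)

estimate = empirical_distribution(shots)
image = estimate.reshape(8, 8)  # row-major: basis index 8 * row + col
max_abs_error = np.abs(estimate - true_probs).max()
```

The statistical error of each bin shrinks like $1/\sqrt{N}$ in the number of shots $N$, which is one reason to use a large sample size per epoch.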
Plot the loss functions versus training steps. The losses trend toward convergence.
imgen.plot_loss_history()
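We have not checked which loss functions src/imgen.py uses. The standard GAN choice is binary cross-entropy for the discriminator together with the non-saturating generator loss. The sketch below implements both, with d_real and d_fake denoting the discriminator's outputs (probabilities of "real") on real and generated samples. It is useful as a reference point: at the ideal equilibrium, where the discriminator outputs 1/2 for every input, the two losses settle at fixed values.

```python
import numpy as np

EPS = 1e-12  # guards against log(0)


def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy: rewards D(real) -> 1 and D(fake) -> 0."""
    d_real, d_fake = np.asarray(d_real), np.asarray(d_fake)
    return -np.mean(np.log(d_real + EPS) + np.log(1.0 - d_fake + EPS))


def generator_loss(d_fake):
    """Non-saturating generator loss: rewards D(fake) -> 1."""
    return -np.mean(np.log(np.asarray(d_fake) + EPS))
```

With D = 1/2 everywhere, discriminator_loss evaluates to $2\ln 2 \approx 1.386$ and generator_loss to $\ln 2 \approx 0.693$. If the demo uses these losses, curves leveling off near those values would indicate that the discriminator can no longer tell the generated distribution from the target.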
Plot the probability distribution of the final circuit output (averaged over the last 5 epochs). It closely matches the original pixel face!
imgen.plot_final_output_distribution(avg_window=5)
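Averaging over a window smooths out epoch-to-epoch fluctuations of the generator. The sketch below shows that averaging and adds one way to put a number on "matches pretty well": the total variation distance between the averaged output and the target, which is 0 for identical distributions and at most 1. It assumes the history is stored as one 64-entry distribution per epoch. The toy history and the choice of metric are ours, not part of ImageGenerator.

```python
import numpy as np


def final_output_distribution(history, avg_window=5):
    """Mean of the last `avg_window` per-epoch distributions (rows of `history`)."""
    history = np.asarray(history, dtype=float)
    return history[-avg_window:].mean(axis=0)


def total_variation_distance(p, q):
    """Half the L1 distance between two probability vectors."""
    return 0.5 * np.abs(np.asarray(p) - np.asarray(q)).sum()


# Toy history: 300 epochs x 64 outcomes; only the last 5 epochs match the target.
target = np.full(64, 1 / 64)
history = np.vstack([np.eye(64)[0]] * 295 + [target] * 5)

final = final_output_distribution(history, avg_window=5)
tvd = total_variation_distance(final, target)
```

Here tvd is 0 because the last five epochs equal the target exactly. The early, all-on-one-pixel epochs would be at distance 63/64 from the uniform target.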
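Get the final variational parameters of the circuit (averaged over the last 10 steps). The returned array has shape (1, 2, 6); its inner dimensions match NUM_LAYERS = 2 and NUM_QUBITS = 6.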
imgen.get_final_params(avg_window=10)
array([[[-0.9736885 , -0.45710206, -0.16529681, -0.09857963,
0.01443378, -0.02832266],
[ 0.08007517, 0.03746863, 0.23725395, -0.03559735,
-0.07716206, -0.05853938]]], dtype=float32)
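To see how 12 parameters can define a distribution over 64 pixels, here is a minimal NumPy statevector sketch of one common layered ansatz: a layer of RY rotations on every qubit, with a chain of CZ gates between consecutive rotation layers. This is an assumption chosen for illustration. We have not verified which gates ImageGenerator uses, only that one parameter per qubit per layer is consistent with the (2, 6) inner shape of the array above.

```python
import numpy as np


def apply_ry(state, theta, qubit, num_qubits):
    """Apply RY(theta) to one qubit; qubit 0 is the most significant bit."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    psi = np.moveaxis(state.reshape([2] * num_qubits), qubit, 0)
    a, b = psi[0], psi[1]  # amplitudes with this qubit in |0> and |1>
    out = np.stack([c * a - s * b, s * a + c * b], axis=0)
    return np.moveaxis(out, 0, qubit).reshape(-1)


def apply_cz(state, q1, q2, num_qubits):
    """Flip the sign of amplitudes where both qubits are |1>."""
    psi = state.reshape([2] * num_qubits).copy()
    index = [slice(None)] * num_qubits
    index[q1] = 1
    index[q2] = 1
    psi[tuple(index)] *= -1
    return psi.reshape(-1)


def output_distribution(params):
    """Measurement probabilities of a hypothetical RY + CZ-chain ansatz.

    params has shape (num_layers, num_qubits): one RY angle per qubit per layer.
    """
    params = np.asarray(params, dtype=float)
    num_layers, num_qubits = params.shape
    state = np.zeros(2 ** num_qubits)
    state[0] = 1.0  # start in |00...0>
    for layer in range(num_layers):
        for q in range(num_qubits):
            state = apply_ry(state, params[layer, q], q, num_qubits)
        if layer < num_layers - 1:  # entangle between rotation layers
            for q in range(num_qubits - 1):
                state = apply_cz(state, q, q + 1, num_qubits)
    return np.abs(state) ** 2
```

With the demo's result, one could try output_distribution(imgen.get_final_params(avg_window=10)[0]).reshape(8, 8), dropping the leading axis of length 1. If ImageGenerator builds a different circuit, that picture will not reproduce the smiley, so treat this only as an illustration of how a parameterized circuit encodes an image.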